# -*- coding:UTF-8 -*-

# author:user
# contact: test@test.com
# datetime:2021/11/9 15:31
# software: PyCharm

"""
文件说明：
    
"""
import json
import os

import numpy as np
from bert_serving.client import BertClient

def encode_standard_question(faq_path="./data/nlu/faq.json", output_dir="./data", client=None):
    """Encode every standard FAQ question and cache the encodings on disk.

    Reads ``faq_path`` as JSON and takes the ``'q'`` field of every entry as a
    question text. All questions are encoded in a single ``encode`` call, and
    two files are written into ``output_dir``:

    * ``standard_questions.npy`` -- the matrix returned by ``client.encode``;
    * ``standard_questions_len.npy`` -- the L2 norm of each row of that matrix.

    Args:
        faq_path: Path of the FAQ JSON file (UTF-8). It presumably holds a list
            of objects with a ``'q'`` key -- the code indexes each entry that way.
        output_dir: Directory the ``.npy`` files are written to (must exist).
        client: Object exposing ``encode(list_of_str)``. When ``None`` (the
            default), a ``BertClient()`` is created here and closed afterwards.

    Raises:
        ValueError: If the FAQ file contains no questions.
    """
    # Use a context manager so the file handle is closed promptly; the former
    # ``json.load(open(...))`` left it open until garbage collection.
    with open(faq_path, "rt", encoding="utf-8") as faq_file:
        data = json.load(faq_file)
    standard_questions = [each['q'] for each in data]
    print("标准问题总量", len(standard_questions))
    # Fail early with a clear message instead of sending an empty batch.
    if not standard_questions:
        raise ValueError("no standard questions found in {!r}".format(faq_path))

    owns_client = client is None
    if owns_client:
        client = BertClient()
    try:
        print("开始计算encoder....")
        standard_questions_encoder = client.encode(standard_questions)
    finally:
        # Only close a client created here; a caller-supplied client stays
        # under the caller's control.
        if owns_client:
            client.close()

    # np.save appends the ".npy" suffix to these base names.
    np.save(os.path.join(output_dir, "standard_questions"), standard_questions_encoder)
    # Per-row Euclidean (L2) norm, identical to sqrt(sum(x * x, axis=1)).
    standard_questions_encoder_len = np.linalg.norm(standard_questions_encoder, axis=1)
    np.save(os.path.join(output_dir, "standard_questions_len"), standard_questions_encoder_len)


if __name__ == "__main__":
    # Script entry point: run the one-off precomputation of the FAQ encodings.
    encode_standard_question()